import sys
import os, sys
import os
import urllib
from tqdm import tqdm
from math import sqrt

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange
from image_synthesis.modeling.codecs.base_codec import BaseCodec
from image_synthesis.modeling.modules.dalle.utils import unmap_pixels, map_pixels
# from dall_e import unmap_pixels, map_pixels

# constants

# URLs of OpenAI's pretrained DALL-E dVAE encoder / decoder checkpoints.
# They are resolved to local files by `download` (which searches local cache
# directories first) and then deserialized with `load_model`.
ENCODER_PATH = 'https://cdn.openai.com/dall-e/encoder.pkl'
DECODER_PATH = 'https://cdn.openai.com/dall-e/decoder.pkl'


# helper functions

def load_model(path):
    """Deserialize and return the object stored at ``path`` via ``torch.load``.

    Every tensor is mapped onto the CPU, whatever device it was saved from.

    NOTE(review): ``torch.load`` unpickles arbitrary objects, so only call this
    on trusted files (e.g. the official OpenAI checkpoints). Newer PyTorch
    releases default to ``weights_only=True``, which may reject pickled
    ``nn.Module`` objects. Confirm against the installed torch version.
    """
    cpu = torch.device('cpu')
    with open(path, 'rb') as handle:
        loaded = torch.load(handle, map_location=cpu)
    return loaded

def download(url, root=(os.path.expanduser("~/.cache/image-synthesis/openai_dvae"),
                        'DATASET/t-qiankunliu-expresource/openai_dvae',
                        )):
    """Return a local path for the file at ``url``, downloading it if needed.

    Each directory in ``root`` is searched, in order, for a file named after the
    URL's basename. The first match is returned without touching the network.
    Otherwise the file is downloaded into the first directory of ``root``
    (created if missing). The data goes to a temporary file that is moved into
    place only after the transfer completes, so an interrupted download never
    leaves a truncated file under the final name.

    Args:
        url: URL of the file to fetch. Its basename is the local filename.
        root: a single directory path, or a sequence of directory paths to
            search. The first entry is the download destination.

    Returns:
        Path of the local file.

    Raises:
        RuntimeError: if a candidate path exists but is not a regular file.
    """
    # `import urllib` alone does not guarantee the `request` submodule is loaded.
    import urllib.request

    # Accept a single path as well as a sequence. Iterating a bare string would
    # otherwise walk its characters.
    if isinstance(root, (str, os.PathLike)):
        roots = [root]
    else:
        roots = list(root)

    filename = os.path.basename(url)

    # Try to find an already-downloaded copy in the candidate directories.
    for root_ in roots:
        download_target = os.path.join(root_, filename)
        if os.path.exists(download_target) and not os.path.isfile(download_target):
            raise RuntimeError(f"{download_target} exists and is not a regular file")
        if os.path.isfile(download_target):
            return download_target

    dest_dir = roots[0]
    download_target = os.path.join(dest_dir, filename)
    download_target_tmp = os.path.join(dest_dir, f'tmp.{filename}')
    os.makedirs(dest_dir, exist_ok=True)

    try:
        with urllib.request.urlopen(url) as source, open(download_target_tmp, "wb") as output:
            # Content-Length may be absent (e.g. chunked responses). In that
            # case show an open-ended progress bar instead of crashing on int(None).
            length = source.info().get("Content-Length")
            total = int(length) if length is not None else None
            with tqdm(total=total, ncols=80) as loop:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break
                    output.write(buffer)
                    loop.update(len(buffer))
        # os.replace overwrites an existing target on every platform (os.rename
        # fails on Windows if the destination exists).
        os.replace(download_target_tmp, download_target)
    except BaseException:
        # Do not leave a partial temporary file behind. Cleanup is best-effort,
        # so the original error is what propagates.
        try:
            os.remove(download_target_tmp)
        except OSError:
            pass
        raise
    return download_target

# adapter class

class OpenAIDiscreteVAE(BaseCodec):
    """Codec adapter around OpenAI's pretrained DALL-E discrete VAE (dVAE).

    Images (B x C x H x W, pixel values in 0-255) are encoded into grids of
    discrete token ids in ``[0, num_tokens)``. Token sequences are decoded back
    into images in 0-255. The pretrained encoder and decoder are fetched from
    ``ENCODER_PATH`` / ``DECODER_PATH`` (cached locally by ``download``) when
    the codec is constructed.
    """

    def __init__(
            self,
            trainable=False,
            token_shape=None,
        ):
        """
        Args:
            trainable: stored as ``self.trainable`` and then applied by
                ``self._set_trainable()`` (inherited from ``BaseCodec``, defined
                outside this file).
            token_shape: (height, width) of the token grid. Defaults to a fresh
                ``[32, 32]`` per instance; a shared mutable default list would be
                aliased across every instance.

        Raises:
            ImportError: if the ``dall_e`` package is not installed.
        """
        super().__init__()
        try:
            # The loaded encoder/decoder appear to be built from `dall_e` module
            # classes (see `half`), so the package must be importable to unpickle them.
            import dall_e  # noqa: F401
        except ImportError as err:
            # Raise instead of printing and calling sys.exit(): a library
            # constructor must not terminate the caller's process, and a bare
            # `except:` would also hide unrelated errors.
            raise ImportError(
                'you need to "pip install git+https://github.com/openai/DALL-E.git" '
                'before you can use the pretrained OpenAI Discrete VAE'
            ) from err
        self.enc = load_model(download(ENCODER_PATH))
        self.dec = load_model(download(DECODER_PATH))

        self.num_layers = 3
        self.image_size = 256
        self.num_tokens = 8192  # codebook size; token ids must lie in [0, 8192)

        self.trainable = trainable
        self.token_shape = [32, 32] if token_shape is None else token_shape
        self._set_trainable()

    def half(self):
        """Cast to float16 only the dall_e ``Conv2d`` layers flagged ``use_float16``.

        Overrides ``nn.Module.half``, which would cast every floating-point
        parameter and buffer. Layers without the flag keep their dtype.
        Returns ``self``, like ``nn.Module.half``.
        """
        from dall_e.utils import Conv2d
        for module in self.modules():
            if isinstance(module, Conv2d) and module.use_float16:
                module._apply(lambda t: t.half() if t.is_floating_point() else t)
        return self

    @property
    def device(self):
        """Device holding the encoder's first block weight (``enc.blocks[0].w``)."""
        return self.enc.blocks[0].w.device

    def preprocess(self, imgs):
        """Scale images from 0-255 to 0-1, then apply ``map_pixels``.

        imgs: B x C x H x W, in the range 0-255
        """
        imgs = imgs.div(255)  # map to 0 - 1
        return map_pixels(imgs)

    def postprocess(self, imgs):
        """Scale images from the range 0-1 back to 0-255.

        imgs: B x C x H x W, in the range 0-1
        """
        return imgs * 255

    def get_tokens(self, imgs, mask=None, enc_with_mask=True, **kwargs):
        """Encode images into flattened token ids.

        Args:
            imgs: B x C x H x W images in the range 0-255.
            mask: optional mask multiplied into the *preprocessed* images
                (cast to their dtype/device) before re-encoding.
            enc_with_mask: if True, the ``'token'`` output comes from encoding
                ``imgs * mask``. If False, it is a copy of the unmasked tokens.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            ``{'token': B x N}`` when ``mask`` is None. Otherwise a dict with
            ``'token'`` (masked-input tokens), ``'target'`` (unmasked tokens),
            ``'mask'`` (True where ``token_type == 0``) and ``'token_type'``,
            all flattened to B x N.
        """
        imgs = self.preprocess(imgs)
        z = torch.argmax(self.enc(imgs), dim=1)  # B x h x w token ids
        if mask is None:
            return {'token': rearrange(z, 'b h w -> b (h w)')}

        if enc_with_mask:
            z_mask = torch.argmax(self.enc(imgs * mask.to(imgs)), dim=1)
        else:
            z_mask = z.clone()
        # NOTE(review): `get_token_type` is neither defined nor imported in this
        # file, so this branch raises NameError unless the name is injected
        # elsewhere. TODO import it from wherever the project defines it.
        token_type = get_token_type(mask, self.token_shape)  # presumably B x 1 x H x W
        mask = token_type == 0
        return {
            'token': z_mask.flatten(1),
            'target': z.flatten(1),
            'mask': mask.flatten(1),
            'token_type': token_type.flatten(1),
        }

    def get_number_of_tokens(self):
        """Return the codebook size (number of distinct token ids)."""
        return self.num_tokens

    def encode(self, imgs):
        """Return the encoder logits for images in 0-255 (no argmax applied)."""
        return self.enc(self.preprocess(imgs))

    def decode(self, img_seq):
        """Decode flattened token ids into images.

        Args:
            img_seq: B x N integer token ids in ``[0, num_tokens)``.

        Returns:
            Images in the range 0-255, built from the first 3 decoder channels.
        """
        _, n = img_seq.shape
        # Use the configured token grid when it matches the sequence length, so
        # non-square grids decode correctly. Otherwise assume a square grid.
        if self.token_shape is not None and self.token_shape[0] * self.token_shape[1] == n:
            h = self.token_shape[0]
        else:
            h = int(sqrt(n))
        img_seq = rearrange(img_seq, 'b (h w) -> b h w', h=h)

        z = F.one_hot(img_seq, num_classes=self.num_tokens)
        z = rearrange(z, 'b h w c -> b c h w').float()
        x_stats = self.dec(z).float()
        # Only the first 3 output channels are used. The sigmoid maps them to 0-1
        # before unmap_pixels.
        x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
        return self.postprocess(x_rec)

    def get_rec_loss(self, img, rec):
        """Mean squared error between ``img`` and ``rec`` (both 0-255) after ``preprocess``."""
        # F.mse_loss already reduces to the mean, so no extra .mean() is needed.
        return F.mse_loss(self.preprocess(img), self.preprocess(rec))

    def forward(self, img):
        """Not supported; use ``get_tokens`` / ``decode`` instead."""
        raise NotImplementedError
